import torch

device = 'cuda'
batch_size = 16

# Memory-pool handle shared across graph captures; populated after the first
# capture so later graphs could reuse the same allocator pool via `pool=...`.
global_graph_memory_pool = None
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
# `states` is allocated BEFORE capture, so it lives outside the graph's
# private memory pool and remains valid to read after each replay.
states = torch.ones(batch_size, device=device)

# Warm up the workload on a side stream before capturing it: some kernels and
# allocator paths initialize lazily, and lazy initialization during capture
# can fail or get baked into the graph. This side-stream warmup (with explicit
# wait_stream handoffs) is the pattern recommended by the PyTorch CUDA graphs
# documentation.
warmup_stream = torch.cuda.Stream()
warmup_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup_stream):
    for _ in range(3):
        states += torch.ones(states.shape, device=device)
torch.cuda.current_stream().wait_stream(warmup_stream)
# The warmup iterations really executed and mutated `states`; reset it so the
# replay loop below produces the same values as a capture without warmup.
states.fill_(1.0)

# Capture one iteration of the workload. Tensors allocated inside the capture
# region (like `ones`) are placed in the graph's memory pool.
with torch.cuda.graph(graph,
        pool=global_graph_memory_pool,
        stream=stream):
    ones = torch.ones(states.shape, device=device)
    states += ones

# Publish the pool handle so subsequent captures can share this graph's pool.
global_graph_memory_pool = graph.pool()

# Each replay re-runs the captured kernels in place: `states` grows by 1 per
# replay, printing 2.0, 3.0, ... 11.0.
for _ in range(10):
    graph.replay()
    torch.cuda.synchronize()  # make sure the replay finished before reading on CPU
    print(f'{states=}')